#ifndef GRADIENT_GENERATOR_H
#define GRADIENT_GENERATOR_H

#include "Loss_generator.h"
#include <Windows.h>
#include <cstring>      // memset
#include <type_traits>  // std::is_pointer
#if defined(_MSC_VER)
// C4996: MSVC warning for deprecated / "unsafe" CRT functions.
// NOTE(review): disabling it in a header affects every file that includes
// this header; consider #pragma warning(push/pop) around the code that needs it.
#pragma warning(disable: 4996)
#endif

class Gradient_generator : public Loss_generator {
	private:
		struct app1_gradient* gradient;
	public:
		Gradient_generator(problem, NN_data*, User_parameter*, struct app1_indata*, struct app1_solution*, struct app1_gradient*);
		void calc_grad_expected_demand(void);
		void clear(void);
		~Gradient_generator() {  }
};

/// Builds a Gradient_generator.
///
/// pr, nn_data, user_parameter, indata and solution are forwarded unchanged to
/// the Loss_generator base; `gradient` is stored as a non-owning pointer to the
/// buffer that calc_grad_expected_demand() and clear() write into.
///
/// Marked `inline` because this definition lives in a header: without it,
/// every translation unit including the header emits its own definition and
/// linking fails with a multiple-definition error.
inline Gradient_generator::Gradient_generator(problem pr, NN_data* nn_data, User_parameter* user_parameter,
	struct app1_indata* indata, struct app1_solution* solution,
	struct app1_gradient* gradient)
	: Loss_generator(pr, nn_data, user_parameter, indata, solution),
	  gradient(gradient) {  // member(param): the argument resolves to the parameter
}

/// Adds the asymmetric-error gradient to gradient->batch_expected_gradient.
///
/// For every l in [0, n_label) and i in [0, nn_data->n_sample):
///     diff   = batch_expected[l][i] - batch_ans[l][i]
///     weight = 1 + th   if batch_expected[l][i] > batch_ans[l][i]
///              1        otherwise
///     batch_expected_gradient[l][i] += diff * weight
/// where th = user_parameter->th.
///
/// The value is ADDED to the existing buffer contents; this function never
/// zeroes it, so repeated calls accumulate.
/// NOTE(review): clear() does not reset batch_expected_gradient — confirm it is
/// reset elsewhere between batches.
/// NOTE(review): n_label is not declared in this file; presumably a constant or
/// member provided via Loss_generator.h — confirm.
///
/// Marked `inline` because this definition lives in a header (avoids
/// multiple-definition link errors).
inline void Gradient_generator::calc_grad_expected_demand() {
	const auto th = user_parameter->th;  // loop-invariant asymmetry weight
	for (int l = 0; l < n_label; l++) {
		for (int i = 0; i < nn_data->n_sample; i++) {
			const auto expected = nn_data->batch_expected[l][i];
			const auto answer = nn_data->batch_ans[l][i];
			// The comparison promotes to 0/1, so entries with expected > answer
			// are weighted by (1 + th) and all others by 1.
			const auto weight = 1 + th * (expected > answer);
			gradient->batch_expected_gradient[l][i] += (expected - answer) * weight;
		}
	}
}




void Gradient_generator::clear() {
	memset(gradient->ccon_gradient_batch, 0, sizeof(gradient->ccon_gradient_batch));
	memset(gradient->test_case_ccon_gradient, 0, sizeof(gradient->test_case_ccon_gradient));
}




#endif